/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved.
 * Description: The description of class OpIRFacade
 */

#include "framework/graph/core/infershape/op_ir_facade.h"
#include "framework/graph/core/infershape/op_ir_ctx.h"
#include "framework/graph/core/node/node_sub_graph.h"
#include "framework/graph/utils/op_desc_utils.h"
#include "framework/graph/utils/tensor_utils.h"
#include "framework/graph/utils/attr_utils.h"
#include "framework/graph/debug/ge_log.h"
#include "graph/op/const_defs.h"

using namespace std;

namespace ge {
OpIRFacade::OpIRFacade(Node& node) : node_(&node) {}

size_t OpIRFacade::GetInputsSize() const
{
    return GetOpDesc().GetInputsSize();
}

bool OpIRFacade::OptionalInputIsSet(uint32_t index) const
{
    return GetOpDesc().InputIsSet(index);
}

/*
 * Shape recorded in the input TensorDesc at `index` of the wrapped node's OpDesc.
 */
Shape OpIRFacade::GetInputShape(uint32_t index) const
{
    // auto&& keeps the exact value category/constness GetInputDesc() yields.
    auto&& inputDesc = GetOpDesc().GetInputDesc(index);
    return inputDesc.GetShape();
}

/*
 * Data type recorded in the input TensorDesc at `index` of the wrapped node's OpDesc.
 */
DataType OpIRFacade::GetInputDataType(uint32_t index) const
{
    // auto&& keeps the exact value category/constness GetInputDesc() yields.
    auto&& inputDesc = GetOpDesc().GetInputDesc(index);
    return inputDesc.GetDataType();
}

/*
 * Looks up a subgraph of the wrapped node by its name through the node's
 * NodeSubGraph role (NodeSubGraph::FindSubGraphPtr).
 */
ComputeGraphPtr OpIRFacade::GetSubGraph(const string& subGraphName) const
{
    auto&& subGraphRole = node_->ROLE(NodeSubGraph);
    return subGraphRole.FindSubGraphPtr(subGraphName);
}

/*
 * Shape of the constant tensor feeding input `index`.
 * Returns a default-constructed Shape when GetConstInputTensor yields nullptr.
 */
Shape OpIRFacade::GetConstInputShape(uint32_t index) const
{
    const ConstTensorPtr constTensor = GetConstInputTensor(index);
    if (constTensor != nullptr) {
        return constTensor->GetTensorDesc().GetShape();
    }
    return Shape();
}

/*
 * Returns the constant (weight) tensor that feeds input `index`, or nullptr.
 *
 * - Non-Const nodes: nullptr when `index` is not below GetInputsSize() or when
 *   OpDescUtils::IsNonConstInput reports the input as non-constant.
 * - Const nodes: the range/constness check above is skipped.
 *
 * The slot in OpDescUtils::GetWeights(*node_) is the number of inputs before
 * `index` that are not flagged by IsNonConstInput; nullptr is returned when
 * that slot is past the end of the weight list.
 */
ConstTensorPtr OpIRFacade::GetConstInputTensor(uint32_t index) const
{
    const bool isConstNode = node_->ROLE(NodeSpec).IsConstOp();
    if (!isConstNode &&
        ((index >= GetInputsSize()) || OpDescUtils::IsNonConstInput(*this->node_, index))) {
        return nullptr;
    }

    // Count the constant inputs preceding `index`; that count selects the weight slot.
    size_t weightPos = 0;
    for (size_t inputIdx = 0; inputIdx < index; ++inputIdx) {
        if (OpDescUtils::IsNonConstInput(*this->node_, inputIdx)) {
            continue;
        }
        ++weightPos;
    }

    const vector<ConstTensorPtr> weights = OpDescUtils::GetWeights(*this->node_);
    if (weightPos >= weights.size()) {
        return nullptr;
    }
    return weights[weightPos];
}

void OpIRFacade::SetOutput(uint32_t index, const Shape& shape, const DataType dataType) const
{
    TensorDesc outputDesc = GetOpDesc().GetOutputDesc(index);
    outputDesc.SetShape(shape);
    outputDesc.SetDataType(dataType);
    GetOpDesc().UpdateOutputDesc(index, outputDesc);
}

void OpIRFacade::SetOutput(uint32_t index, const Shape& shape, const DataType dataType, const string& handle) const
{
    TensorDesc outputDesc = GetOpDesc().GetOutputDesc(index);
    outputDesc.SetShape(shape);
    outputDesc.SetDataType(dataType);
    TensorUtils::SetTensorArrayHandle(outputDesc, handle); // dataflowhandle
    GetOpDesc().UpdateOutputDesc(index, outputDesc);
}

/*
 * Resolves the subgraph whose name is stored in the string attribute
 * `subGraphAttr` of the wrapped node's OpDesc.
 *
 * If the attribute cannot be read, an error naming the node is logged,
 * `subGraph` is left unchanged and nullptr is returned. Otherwise the result of
 * the name-based GetSubGraph lookup is stored into `subGraph` and also returned
 * (that lookup may itself yield nullptr).
 */
ComputeGraphPtr OpIRFacade::GetSubGraph(const std::string& subGraphAttr, ge::ComputeGraphPtr& subGraph) const
{
    ge::OpDesc& opDesc = GetOpDesc();
    std::string subGraphName;
    GE_CHK_BOOL_EXEC(ge::AttrUtils::GetStr(opDesc, subGraphAttr, subGraphName), return nullptr,
        "failed to get %s node subgraph name", node_->ROLE(NodeSpec).Name().c_str());

    subGraph = GetSubGraph(subGraphName);
    return subGraph;
}

/*
 * Copies every input TensorDesc of the wrapped node's OpDesc into `inputs`,
 * replacing whatever `inputs` previously held.
 *
 * @param inputs  Output parameter; overwritten with the node's input descriptors.
 * @return        Always GRAPH_SUCCESS.
 */
GraphErrCodeStatus OpIRFacade::GetInputs(std::vector<ge::TensorDesc>& inputs) const
{
    // Bind by const reference: if GetAllInputsDesc() returns a reference this avoids
    // an extra full copy of the container; if it returns by value, the temporary's
    // lifetime is extended to the end of this function.
    const auto& allInputDesc = GetOpDesc().GetAllInputsDesc();
    // assign() replaces the contents in one step (same result as clear() + insert()).
    inputs.assign(allInputDesc.begin(), allInputDesc.end());
    return GRAPH_SUCCESS;
}

size_t OpIRFacade::GetOutputsSize() const
{
    return GetOpDesc().GetOutputsSize();
}
} // namespace ge
